import fitz
import os
import numpy as np
import json
from openai import OpenAI
import re
from tqdm import tqdm

# API key for the SiliconFlow OpenAI-compatible endpoint.
# SECURITY: a key is hard-coded below as a fallback. Prefer exporting
# SILICONFLOW_API_KEY (it takes precedence) and revoke/remove the committed key.
# OPENAI_API_KEY is still overwritten unconditionally, so an unrelated
# OPENAI_API_KEY already in the environment is not sent to SiliconFlow.
os.environ["OPENAI_API_KEY"] = (
    os.getenv("SILICONFLOW_API_KEY")
    or "sk-aswwzdvvimeybiiokqebixpkhbmcftlbgkubssfuodifqjcf"
)
# OpenAI-compatible client pointed at the SiliconFlow API base URL; used by the
# embedding, answering and evaluation functions in this file.
client = OpenAI(
    base_url="https://api.siliconflow.cn/v1",
    api_key=os.getenv("OPENAI_API_KEY"),  # read back from the environment set above
)

def extract_text_from_pdf(pdf_path):
    """Return the plain text of every page of the PDF at ``pdf_path``.

    Page texts (``page.get_text("text")``) are concatenated in page order with
    no added separator. The document is opened in a ``with`` block so the file
    handle is released even if text extraction raises.

    Args:
        pdf_path: Path to the PDF file.

    Returns:
        The concatenated text of all pages ("" for a document with no text).
    """
    with fitz.open(pdf_path) as pdf_document:
        # join() consumes the generator before the document is closed.
        return "".join(page.get_text("text") for page in pdf_document)

def chunk_text(text, chunk_size=800, overlap=0):
    """Split ``text`` into fixed-size character chunks.

    Consecutive chunks start ``chunk_size - overlap`` characters apart, so each
    chunk shares ``overlap`` characters with the next one. The last chunk may be
    shorter than ``chunk_size``.

    Args:
        text: String to split.
        chunk_size: Maximum number of characters per chunk; must be positive.
        overlap: Number of characters shared by consecutive chunks; must satisfy
            ``0 <= overlap < chunk_size``.

    Returns:
        List of non-empty substrings of ``text`` (``[]`` for empty text).

    Raises:
        ValueError: If ``chunk_size`` is not positive, or ``overlap`` is negative
            (chunks would skip characters) or ``>= chunk_size`` (the step
            would be non-positive, which previously raised an opaque
            ``range()`` error or silently produced no chunks).
    """
    if chunk_size <= 0:
        raise ValueError(f"chunk_size must be positive, got {chunk_size}")
    if not 0 <= overlap < chunk_size:
        raise ValueError(
            f"overlap must satisfy 0 <= overlap < chunk_size ({chunk_size}), got {overlap}"
        )
    step = chunk_size - overlap
    # Every start index is < len(text) and chunk_size >= 1, so no chunk is empty.
    return [text[i:i + chunk_size] for i in range(0, len(text), step)]

class SimpleVectorStore:
    """Minimal in-memory vector store with brute-force cosine-similarity search.

    Documents, their embedding vectors and their metadata are kept in three
    parallel lists: index ``i`` of each list describes the same entry.
    """

    def __init__(self, dimension=1536):
        # Stored for reference only; vectors are not checked against it.
        self.dimension = dimension
        self.vectors = []
        self.documents = []
        self.metadata = []

    def add_documents(self, documents, vectors=None, metadata=None):
        """Append documents together with optional vectors and metadata.

        Args:
            documents: Iterable of document texts.
            vectors: Optional iterable of embedding vectors, one per document.
                When omitted, every document is stored with vector ``None`` and
                is therefore skipped by ``search``.
            metadata: Optional iterable of metadata dicts, one per document.
                When omitted, each document gets its own empty dict.

        Raises:
            ValueError: If ``vectors`` or ``metadata`` does not have exactly one
                entry per document. Nothing is stored in that case.
        """
        documents = list(documents)
        vectors = [None] * len(documents) if vectors is None else list(vectors)
        metadata = [{} for _ in documents] if metadata is None else list(metadata)
        if len(vectors) != len(documents) or len(metadata) != len(documents):
            # zip() would silently drop the unmatched tail; fail loudly instead.
            raise ValueError(
                f"got {len(documents)} documents, {len(vectors)} vectors "
                f"and {len(metadata)} metadata entries; lengths must match"
            )
        self.documents.extend(documents)
        self.vectors.extend(vectors)
        self.metadata.extend(metadata)

    def search(self, query_vector, top_k=5):
        """Return the ``top_k`` stored entries most cosine-similar to ``query_vector``.

        Entries stored without a vector are ignored. If the query or a stored
        vector has zero norm, its similarity is reported as 0.0 rather than NaN,
        which would otherwise make the ranking meaningless.

        Args:
            query_vector: Query embedding (any array-like accepted by NumPy).
            top_k: Maximum number of results to return.

        Returns:
            List of dicts with keys ``"document"``, ``"score"`` (float) and
            ``"metadata"``, ordered by descending score; ties keep insertion
            order (``list.sort`` is stable, including with ``reverse=True``).
        """
        if not self.vectors or not self.documents:
            return []
        query_array = np.asarray(query_vector, dtype=float)
        query_norm = np.linalg.norm(query_array)  # loop-invariant
        similarities = []
        for i, vector in enumerate(self.vectors):
            if vector is None:
                continue
            vector_array = np.asarray(vector, dtype=float)
            denominator = query_norm * np.linalg.norm(vector_array)
            if denominator:
                similarity = float(np.dot(query_array, vector_array) / denominator)
            else:
                similarity = 0.0  # cosine similarity is undefined for zero vectors
            similarities.append((i, similarity))
        similarities.sort(key=lambda pair: pair[1], reverse=True)
        return [
            {
                "document": self.documents[i],
                "score": float(score),
                "metadata": self.metadata[i],
            }
            for i, score in similarities[:top_k]
        ]
def create_embeddings(texts, model="BAAI/bge-m3", batch_size=100):
    """Embed ``texts`` with the module-level ``client``, sending requests in batches.

    Args:
        texts: Sequence of strings to embed (must support ``len`` and slicing).
        model: Embedding model name passed to the API.
        batch_size: Maximum number of texts per request; tune it to the
            provider's request limits. Must be positive.

    Returns:
        List of embedding vectors, concatenated across batches in the order the
        API returns ``response.data`` (presumably input order; the code does not
        reorder by item index). Returns ``[]`` for empty input.

    Raises:
        ValueError: If ``batch_size`` is not positive.
    """
    if not texts:
        return []
    if batch_size <= 0:
        raise ValueError(f"batch_size must be positive, got {batch_size}")
    all_embeddings = []
    for start in range(0, len(texts), batch_size):
        response = client.embeddings.create(
            input=texts[start:start + batch_size],
            model=model,
        )
        all_embeddings.extend(item.embedding for item in response.data)
    return all_embeddings

def process_document(pdf_path, chunk_size=800):
    """Extract, chunk and embed a PDF, and load the result into a vector store.

    Args:
        pdf_path: Path to the PDF to index.
        chunk_size: Characters per chunk (chunks do not overlap).

    Returns:
        Tuple ``(chunks, vector_store, doc_info)``: ``chunks`` is the list of
        text chunks; ``vector_store`` is a ``SimpleVectorStore`` whose entries
        carry ``{"chunk_index": i, "source": pdf_path}`` metadata; ``doc_info``
        is a dict with keys ``"chunks"`` and ``"source"``.
    """
    print("Extracting text from document...")
    document_text = extract_text_from_pdf(pdf_path)

    print("Chunking text into non-overlapping segments...")
    chunks = chunk_text(document_text, chunk_size=chunk_size, overlap=0)
    print(f"Created {len(chunks)} chunks")

    print("Generating embeddings for chunks...")
    embeddings = create_embeddings(chunks)

    # chunk_index ties each stored entry back to its position in `chunks`.
    chunk_metadata = [
        {"chunk_index": index, "source": pdf_path}
        for index, _ in enumerate(chunks)
    ]
    store = SimpleVectorStore()
    store.add_documents(chunks, embeddings, chunk_metadata)

    doc_info = {"chunks": chunks, "source": pdf_path}
    return chunks, store, doc_info

def calculate_chunk_values(query, chunks, vector_store, irrelevant_chunk_penalty=0.2):
    """Compute a value per chunk: its similarity to ``query`` minus a penalty.

    Args:
        query: User question to embed and compare against every chunk.
        chunks: The chunk texts; only ``len(chunks)`` is used here.
        vector_store: Store whose entries carry a ``"chunk_index"`` metadata key.
        irrelevant_chunk_penalty: Amount subtracted from every similarity score,
            so weakly related chunks end up with a negative value.

    Returns:
        List of ``len(chunks)`` floats; position ``i`` is chunk ``i``'s value.
        A chunk missing from the search results counts as similarity 0.0.
    """
    query_embedding = create_embeddings([query])[0]
    num_chunks = len(chunks)
    # Request every chunk so each one can receive a score.
    hits = vector_store.search(query_embedding, top_k=num_chunks)

    # Map chunk index -> similarity, e.g.
    #   [{"metadata": {"chunk_index": 1}, "score": 0.9},
    #    {"metadata": {"chunk_index": 2}, "score": 0.8}]  ->  {1: 0.9, 2: 0.8}
    score_by_index = {}
    for hit in hits:
        score_by_index[hit["metadata"]["chunk_index"]] = hit["score"]

    # value = relevance - penalty; chunks without a score use relevance 0.0.
    return [
        score_by_index.get(index, 0.0) - irrelevant_chunk_penalty
        for index in range(num_chunks)
    ]

def find_best_segments(chunk_values, max_segment_length=20, total_max_length=30, min_segment_value=0.2):
    """Greedily select non-overlapping runs of consecutive chunks with the highest total value.

    Each round considers every half-open range ``[start, end)`` that meets three
    conditions:

    - it does not overlap an already-selected segment;
    - it is at most ``max_segment_length`` chunks long;
    - it fits in what is left of the ``total_max_length`` budget.

    The round keeps the range with the largest sum of ``chunk_values``. Rounds
    stop when no candidate's value is strictly greater than
    ``min_segment_value``, or when the budget is used up.

    Args:
        chunk_values: Per-chunk values (e.g. relevance minus penalty), indexed
            by chunk position.
        max_segment_length: Maximum number of chunks in one segment.
        total_max_length: Maximum number of chunks across all segments.
        min_segment_value: A segment is selected only if its value exceeds this.

    Returns:
        ``(segments, scores)``: ``segments`` is a list of ``(start, end)``
        chunk ranges sorted by ``start``. ``scores[i]`` is the total value of
        ``segments[i]``, so the two lists stay aligned after sorting.
    """
    print("Finding optimal continuous text segments...")
    num_chunks = len(chunk_values)
    selected = []  # (start, end, score) in the order segments were found
    total_included_chunks = 0
    while total_included_chunks < total_max_length:
        remaining = total_max_length - total_included_chunks
        best_score = min_segment_value  # minimum threshold for a segment
        best_segment = None
        for start in range(num_chunks):
            longest = min(max_segment_length, num_chunks - start, remaining)
            segment_value = 0.0  # running sum of chunk_values[start:end]
            for end in range(start + 1, start + longest + 1):
                # [start, end) and [s, e) overlap iff start < e and s < end.
                # Checking only the endpoints would miss a candidate that fully
                # contains a selected segment. Once a candidate overlaps, every
                # longer one from the same start does too, so stop extending.
                if any(start < e and s < end for s, e, _ in selected):
                    break
                segment_value += chunk_values[end - 1]
                if segment_value > best_score:
                    best_score = segment_value
                    best_segment = (start, end)
        if best_segment is None:
            break
        selected.append((best_segment[0], best_segment[1], best_score))
        total_included_chunks += best_segment[1] - best_segment[0]
        print(f"Found segment {best_segment} with score {best_score:.4f}")
    # Sort segments and scores together so they remain index-aligned.
    selected.sort(key=lambda item: item[0])
    best_segments = [(s, e) for s, e, _ in selected]
    segment_scores = [score for _, _, score in selected]
    return best_segments, segment_scores
def reconstruct_segments(chunks, best_segments):
    """Turn ``(start, end)`` chunk ranges back into text segments.

    Args:
        chunks: List of chunk texts.
        best_segments: Iterable of half-open ``(start, end)`` chunk ranges.

    Returns:
        One dict per range, in input order, with keys ``"text"`` (the chunks in
        the range joined by a single space) and ``"segment_range"``.

    Note:
        If the chunks are raw character slices (as ``chunk_text`` produces), the
        joining space can land inside a word that was split at a chunk boundary.
    """
    return [
        {"text": " ".join(chunks[start:end]), "segment_range": (start, end)}
        for start, end in best_segments
    ]
def format_segments_for_context(segments):
    """Render segments as a single context string for the LLM prompt.

    Each segment contributes three parts:

    - a header naming its 1-based position and its inclusive chunk range;
    - its text;
    - an 80-character dash separator.

    All parts are joined by blank lines.

    Args:
        segments: Dicts with ``"text"`` and ``"segment_range"`` (half-open
            ``(start, end)``) keys.

    Returns:
        The formatted context ("" when ``segments`` is empty).
    """
    separator = "-" * 80
    parts = []
    for number, segment in enumerate(segments, start=1):
        seg_range = segment['segment_range']
        header = f"SEGMENT {number} (Chunks {seg_range[0]}-{seg_range[1] - 1}):"
        parts.extend((header, segment['text'], separator))
    return "\n\n".join(parts)
def generate_response(query, context, model="Qwen/Qwen2-1.5B-Instruct"):
    """Ask the chat model to answer ``query`` grounded in ``context``.

    Args:
        query: User question.
        context: Retrieved text to supply in the user message.
        model: Chat model name passed to the API.

    Returns:
        The message content of the first completion choice (temperature 0).
    """
    print("Generating response using relevant segments as context...")
    # NOTE: the indentation inside these triple-quoted literals is part of the
    # prompt text actually sent to the model.
    instructions = """You are a helpful assistant that answers questions based on the provided context.
    The context consists of document segments that have been retrieved as relevant to the user's query.
    Use the information from these segments to provide a comprehensive and accurate answer.
    If the context doesn't contain relevant information to answer the question, say so clearly."""
    question_block = f"""
        Context:
        {context}
        Question: {query}
        Please provide a helpful answer based on the context provided.
        """
    conversation = [
        {"role": "system", "content": instructions},
        {"role": "user", "content": question_block},
    ]
    completion = client.chat.completions.create(
        model=model,
        messages=conversation,
        temperature=0,
    )
    first_choice = completion.choices[0]
    return first_choice.message.content
# irrelevant_chunk_penalty (irrelevance penalty): amount subtracted from every chunk's relevance score before segment selection
def rag_with_rse(pdf_path, query, chunk_size=800, irrelevant_chunk_penalty=0.2,
                 max_segment_length=20, total_max_length=30, min_segment_value=0.2):
    """Answer ``query`` over a PDF using Relevant Segment Extraction (RSE).

    Pipeline:

    1. Chunk and embed the PDF.
    2. Value each chunk as its similarity to the query minus a penalty.
    3. Greedily pick high-value runs of consecutive chunks.
    4. Stitch those runs back into text segments.
    5. Pass the segments to the chat model as context.

    Args:
        pdf_path: Path of the PDF to index.
        query: User question.
        chunk_size: Characters per chunk.
        irrelevant_chunk_penalty: Amount subtracted from every chunk's
            similarity score so that weakly related chunks get negative value.
        max_segment_length: Maximum chunks in one segment (passed to
            ``find_best_segments``).
        total_max_length: Maximum chunks across all segments (passed to
            ``find_best_segments``).
        min_segment_value: Minimum value a segment must exceed to be selected
            (passed to ``find_best_segments``).

    Returns:
        Dict with keys ``"query"``, ``"segments"`` (output of
        ``reconstruct_segments``) and ``"response"`` (the model's answer).
    """
    print("\n=== STARTING RAG WITH RELEVANT SEGMENT EXTRACTION ===")
    print(f"Query: {query}")
    chunks, vector_store, _doc_info = process_document(pdf_path, chunk_size)
    print("\nCalculating relevance scores and chunk values...")
    chunk_values = calculate_chunk_values(query, chunks, vector_store, irrelevant_chunk_penalty)
    best_segments, _scores = find_best_segments(
        chunk_values,
        max_segment_length=max_segment_length,
        total_max_length=total_max_length,
        min_segment_value=min_segment_value,
    )
    print("\nReconstructing text segments from chunks...")
    segments = reconstruct_segments(chunks, best_segments)
    context = format_segments_for_context(segments)
    response = generate_response(query, context)
    result = {
        "query": query,
        "segments": segments,
        "response": response
    }
    print("\n=== FINAL RESPONSE ===")
    print(response)
    return result
def standard_top_k_retrieval(pdf_path, query, k=10, chunk_size=800):
    """Baseline RAG: answer ``query`` from the ``k`` chunks most similar to it.

    Args:
        pdf_path: Path of the PDF to index.
        query: User question.
        k: Number of chunks to retrieve.
        chunk_size: Characters per chunk.

    Returns:
        Dict with keys ``"query"``, ``"chunks"`` (retrieved chunk texts in
        ranking order) and ``"response"`` (the model's answer).
    """
    print("\n=== STARTING STANDARD TOP-K RETRIEVAL ===")
    print(f"Query: {query}")
    _chunks, store, _info = process_document(pdf_path, chunk_size)

    print("Creating query embedding and retrieving chunks...")
    query_vector = create_embeddings([query])[0]
    top_hits = store.search(query_vector, top_k=k)
    retrieved_chunks = []
    for hit in top_hits:
        retrieved_chunks.append(hit["document"])

    # Number the chunks from 1 and separate them with blank lines.
    numbered = (
        f"CHUNK {position}:\n{text}"
        for position, text in enumerate(retrieved_chunks, start=1)
    )
    context = "\n\n".join(numbered)

    answer = generate_response(query, context)
    result = {"query": query, "chunks": retrieved_chunks, "response": answer}
    print("\n=== FINAL RESPONSE ===")
    print(answer)
    return result
def evaluate_methods(pdf_path, query, reference_answer=None, model="Qwen/Qwen2-1.5B-Instruct"):
    """Run RSE and standard top-k retrieval on the same query and compare them.

    When ``reference_answer`` is truthy, an LLM judge compares both answers
    against it and the verdict is printed.

    Args:
        pdf_path: Path of the PDF to index (each method indexes it separately).
        query: User question.
        reference_answer: Optional gold answer used by the judge.
        model: Chat model used as the judge. Judging runs at temperature 0,
            matching ``generate_response``, so repeated runs are comparable.

    Returns:
        Dict with keys:

        - ``"rse_result"``: the result of the RSE pipeline.
        - ``"standard_result"``: the result of standard top-k retrieval.
        - ``"evaluation"``: the judge's text, or ``None`` when no reference
          answer was given.
    """
    print("\n========= EVALUATION =========\n")
    rse_result = rag_with_rse(pdf_path, query)
    standard_result = standard_top_k_retrieval(pdf_path, query)
    evaluation_text = None
    if reference_answer:
        print("\n=== COMPARING RESULTS ===")
        evaluation_prompt = f"""
            Query: {query}
            Reference Answer:
            {reference_answer}
            Response from Standard Retrieval:
            {standard_result["response"]}
            Response from Relevant Segment Extraction:
            {rse_result["response"]}
            Compare these two responses against the reference answer. Which one is:
            1. More accurate and comprehensive
            2. Better at addressing the user's query
            3. Less likely to include irrelevant information
            Explain your reasoning for each point.
        """
        print("Evaluating responses against reference answer...")
        evaluation = client.chat.completions.create(
            model=model,
            messages=[
                {"role": "system", "content": "You are an objective evaluator of RAG system responses."},
                {"role": "user", "content": evaluation_prompt}
            ],
            temperature=0,  # deterministic judging, consistent with generate_response
        )
        evaluation_text = evaluation.choices[0].message.content
        print("\n=== EVALUATION RESULTS ===")
        print(evaluation_text)
    return {
        "rse_result": rse_result,
        "standard_result": standard_result,
        "evaluation": evaluation_text,
    }

# Example run: take the first question of the validation set and compare
# RSE against standard top-k retrieval on the sample PDF.
# JSON is UTF-8 by specification; without an explicit encoding, open() uses the
# locale default, which can fail or garble non-ASCII text (e.g. on Windows).
with open('data/val.json', encoding='utf-8') as f:
    data = json.load(f)
query = data[0]['question']
reference_answer = data[0]['ideal_answer']
pdf_path = "data/AI_Information.pdf"
results = evaluate_methods(pdf_path, query, reference_answer)

